{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| [07_information_extraction/01_CRF实体识别模型.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/07_information_extraction/01_CRF实体识别模型.ipynb)  | 从头实现CRF实体识别模型  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/07_information_extraction/01_CRF实体识别模型.ipynb) |\n",
    "\n",
    "# CRF的中文实体识别模型\n",
    "\n",
    "非结构化的文本内容有很多丰富的信息，但找到相关的知识始终是一个具有挑战性的任务，命名实体识别可以帮忙找到其中核心的实体词，对新词发现也有挖掘扩展作用。\n",
    "\n",
    "前面我们用隐马尔可夫模型（HMM）自己尝试训练过一个分词器，其实 HMM 也可以用来训练命名实体识别器，但在本节，我们讲另外一个算法——条件随机场（CRF），来训练一个命名实体识别器。\n",
    "\n",
    "## 条件随机场（CRF）\n",
    "\n",
    "条件随机场（Conditional Random Fields，简称 CRF）是给定一组输入序列条件下另一组输出序列的条件概率分布模型，在自然语言处理中得到了广泛应用。\n",
    "\n",
    "首先，我们来看看什么是随机场。“随机场”的名字取的很玄乎，其实理解起来不难。随机场是由若干个位置组成的整体，当按照某种分布给每一个位置随机赋予一个值之后，其全体就叫做随机场。\n",
    "\n",
    "对于 CRF，我们给出准确的数学语言描述：设 X 与 Y 是随机变量，P(Y|X) 是给定 X 时 Y 的条件概率分布，若随机变量 Y 构成的是一个马尔科夫随机场，则称条件概率分布 P(Y|X) 是条件随机场。\n",
    "\n",
    "## 基于 CRF 的中文命名实体识别模型实现\n",
    "\n",
    "在常规的命名实体识别中，通用场景下最常提取的是时间、人物、地点及组织机构名，因此本模型也将提取以上四种实体。\n",
    "\n",
    "\n",
    "引入库："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\r\n",
      "Requirement already satisfied: sklearn_crfsuite in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (0.3.6)\r\n",
      "Requirement already satisfied: python-crfsuite>=0.8.3 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sklearn_crfsuite) (0.9.7)\r\n",
      "Requirement already satisfied: tqdm>=2.0 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sklearn_crfsuite) (4.59.0)\r\n",
      "Requirement already satisfied: tabulate in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sklearn_crfsuite) (0.8.9)\r\n",
      "Requirement already satisfied: six in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sklearn_crfsuite) (1.15.0)\r\n"
     ]
    }
   ],
   "source": [
    "!pip install sklearn_crfsuite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import sklearn_crfsuite\n",
    "from sklearn_crfsuite import metrics\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CorpusProcess(object):\n",
    "\n",
    "    def __init__(self, data_dir='data/'):\n",
    "        \"\"\"初始化\"\"\"\n",
    "        self.train_corpus_path = data_dir + \"1980_01_rmrb.txt\"\n",
    "        self.process_corpus_path = data_dir + \"result_rmrb.txt\"\n",
    "        self._maps = {'t': 'T','nr': 'PER', 'ns': 'ORG','nt': 'LOC'}\n",
    "        \n",
    "    def read_corpus_from_file(self, file_path):\n",
    "        \"\"\"读取语料\"\"\"\n",
    "        f = open(file_path, 'r',encoding='utf-8')\n",
    "        lines = f.readlines()\n",
    "        f.close()\n",
    "        return lines\n",
    "    \n",
    "    def write_corpus_to_file(self, data, file_path):\n",
    "        \"\"\"写语料\"\"\"\n",
    "        f = open(file_path, 'wb')\n",
    "        f.write(data)\n",
    "        f.close()\n",
    "        \n",
    "    def q_to_b(self,q_str):\n",
    "        \"\"\"全角转半角\"\"\"\n",
    "        b_str = \"\"\n",
    "        for uchar in q_str:\n",
    "            inside_code = ord(uchar)\n",
    "            if inside_code == 12288:  # 全角空格直接转换\n",
    "                inside_code = 32\n",
    "            elif 65374 >= inside_code >= 65281:  # 全角字符（除空格）根据关系转化\n",
    "                inside_code -= 65248\n",
    "            b_str += chr(inside_code)\n",
    "        return b_str\n",
    "    \n",
    "    def b_to_q(self,b_str):\n",
    "        \"\"\"半角转全角\"\"\"\n",
    "        q_str = \"\"\n",
    "        for uchar in b_str:\n",
    "            inside_code = ord(uchar)\n",
    "            if inside_code == 32:  # 半角空格直接转化\n",
    "                inside_code = 12288\n",
    "            elif 126 >= inside_code >= 32:  # 半角字符（除空格）根据关系转化\n",
    "                inside_code += 65248\n",
    "            q_str += chr(inside_code)\n",
    "        return q_str\n",
    "    \n",
    "    def pre_process(self):\n",
    "        \"\"\"语料预处理 \"\"\"\n",
    "        lines = self.read_corpus_from_file(self.train_corpus_path)\n",
    "        new_lines = []\n",
    "        for line in lines:\n",
    "            words = self.q_to_b(line.strip()).split('  ')\n",
    "            pro_words = self.process_t(words)\n",
    "            pro_words = self.process_nr(pro_words)\n",
    "            pro_words = self.process_k(pro_words)\n",
    "            new_lines.append('  '.join(pro_words[1:]))\n",
    "        self.write_corpus_to_file(data='\\n'.join(new_lines).encode('utf-8'), file_path=self.process_corpus_path)\n",
    "    \n",
    "    def process_k(self, words):\n",
    "        \"\"\"处理大粒度分词,合并语料库中括号中的大粒度分词,类似：[国家/n  环保局/n]nt \"\"\"\n",
    "        pro_words = []\n",
    "        index = 0\n",
    "        temp = ''\n",
    "        while True:\n",
    "            word = words[index] if index < len(words) else ''\n",
    "            if '[' in word:\n",
    "                temp += re.sub(pattern='/[a-zA-Z]*', repl='', string=word.replace('[', ''))\n",
    "            elif ']' in word:\n",
    "                w = word.split(']')\n",
    "                temp += re.sub(pattern='/[a-zA-Z]*', repl='', string=w[0])\n",
    "                pro_words.append(temp+'/'+w[1])\n",
    "                temp = ''\n",
    "            elif temp:\n",
    "                temp += re.sub(pattern='/[a-zA-Z]*', repl='', string=word)\n",
    "            elif word:\n",
    "                pro_words.append(word)\n",
    "            else:\n",
    "                break\n",
    "            index += 1\n",
    "        return pro_words\n",
    "    \n",
    "    def process_nr(self, words):\n",
    "        \"\"\" 处理姓名，合并语料库分开标注的姓和名，类似：温/nr  家宝/nr\"\"\"\n",
    "        pro_words = []\n",
    "        index = 0\n",
    "        while True:\n",
    "            word = words[index] if index < len(words) else ''\n",
    "            if '/nr' in word:\n",
    "                next_index = index + 1\n",
    "                if next_index < len(words) and '/nr' in words[next_index]:\n",
    "                    pro_words.append(word.replace('/nr', '') + words[next_index])\n",
    "                    index = next_index\n",
    "                else:\n",
    "                    pro_words.append(word)\n",
    "            elif word:\n",
    "                pro_words.append(word)\n",
    "            else:\n",
    "                break\n",
    "            index += 1\n",
    "        return pro_words\n",
    "    \n",
    "    def process_t(self, words):\n",
    "        \"\"\"处理时间,合并语料库分开标注的时间词，类似： （/w  一九九七年/t  十二月/t  三十一日/t  ）/w   \"\"\"\n",
    "        pro_words = []\n",
    "        index = 0\n",
    "        temp = ''\n",
    "        while True:\n",
    "            word = words[index] if index < len(words) else ''\n",
    "            if '/t' in word:\n",
    "                temp = temp.replace('/t', '') + word\n",
    "            elif temp:\n",
    "                pro_words.append(temp)\n",
    "                pro_words.append(word)\n",
    "                temp = ''\n",
    "            elif word:\n",
    "                pro_words.append(word)\n",
    "            else:\n",
    "                break\n",
    "            index += 1\n",
    "        return pro_words\n",
    "    \n",
    "    def pos_to_tag(self, p):\n",
    "        \"\"\"由词性提取标签\"\"\"\n",
    "        t = self._maps.get(p, None)\n",
    "        return t if t else 'O'\n",
    "    \n",
    "    def tag_perform(self, tag, index):\n",
    "        \"\"\"标签使用BIO模式\"\"\"\n",
    "        if index == 0 and tag != 'O':\n",
    "            return 'B_{}'.format(tag)\n",
    "        elif tag != 'O':\n",
    "            return 'I_{}'.format(tag)\n",
    "        else:\n",
    "            return tag\n",
    "        \n",
    "    def pos_perform(self, pos):\n",
    "        \"\"\"去除词性携带的标签先验知识\"\"\"\n",
    "        if pos in self._maps.keys() and pos != 't':\n",
    "            return 'n'\n",
    "        else:\n",
    "            return pos\n",
    "        \n",
    "    def initialize(self):\n",
    "        \"\"\"初始化 \"\"\"\n",
    "        lines = self.read_corpus_from_file(self.process_corpus_path)\n",
    "        words_list = [line.strip().split('  ') for line in lines if line.strip()]\n",
    "        del lines\n",
    "        self.init_sequence(words_list)\n",
    "        \n",
    "    def init_sequence(self, words_list):\n",
    "        \"\"\"初始化字序列、词性序列、标记序列 \"\"\"\n",
    "        words_seq = [[word.split('/')[0] for word in words] for words in words_list]\n",
    "        pos_seq = [[word.split('/')[1] for word in words] for words in words_list]\n",
    "        tag_seq = [[self.pos_to_tag(p) for p in pos] for pos in pos_seq]\n",
    "        self.pos_seq = [[[pos_seq[index][i] for _ in range(len(words_seq[index][i]))]\n",
    "                        for i in range(len(pos_seq[index]))] for index in range(len(pos_seq))]\n",
    "        self.tag_seq = [[[self.tag_perform(tag_seq[index][i], w) for w in range(len(words_seq[index][i]))]\n",
    "                        for i in range(len(tag_seq[index]))] for index in range(len(tag_seq))]\n",
    "        self.pos_seq = [['un']+[self.pos_perform(p) for pos in pos_seq for p in pos]+['un'] for pos_seq in self.pos_seq]\n",
    "        self.tag_seq = [[t for tag in tag_seq for t in tag] for tag_seq in self.tag_seq]\n",
    "        self.word_seq = [['<BOS>']+[w for word in word_seq for w in word]+['<EOS>'] for word_seq in words_seq]   \n",
    "        \n",
    "    def extract_feature(self, word_grams):\n",
    "        \"\"\"特征选取\"\"\"\n",
    "        features, feature_list = [], []\n",
    "        for index in range(len(word_grams)):\n",
    "            for i in range(len(word_grams[index])):\n",
    "                word_gram = word_grams[index][i]\n",
    "                feature = {'w-1': word_gram[0], 'w': word_gram[1], 'w+1': word_gram[2],\n",
    "                           'w-1:w': word_gram[0]+word_gram[1], 'w:w+1': word_gram[1]+word_gram[2],\n",
    "                           # 'p-1': self.pos_seq[index][i], 'p': self.pos_seq[index][i+1],\n",
    "                           # 'p+1': self.pos_seq[index][i+2],\n",
    "                           # 'p-1:p': self.pos_seq[index][i]+self.pos_seq[index][i+1],\n",
    "                           # 'p:p+1': self.pos_seq[index][i+1]+self.pos_seq[index][i+2],\n",
    "                           'bias': 1.0}\n",
    "                feature_list.append(feature)\n",
    "            features.append(feature_list)\n",
    "            feature_list = []\n",
    "        return features \n",
    "    \n",
    "    def segment_by_window(self, words_list=None, window=3):\n",
    "        \"\"\"窗口切分\"\"\"\n",
    "        words = []\n",
    "        begin, end = 0, window\n",
    "        for _ in range(1, len(words_list)):\n",
    "            if end > len(words_list): break\n",
    "            words.append(words_list[begin:end])\n",
    "            begin = begin + 1\n",
    "            end = end + 1\n",
    "        return words\n",
    "    \n",
    "    def generator(self):\n",
    "        \"\"\"训练数据\"\"\"\n",
    "        word_grams = [self.segment_by_window(word_list) for word_list in self.word_seq]\n",
    "        features = self.extract_feature(word_grams)\n",
    "        return features, self.tag_seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CRF_NER(object):\n",
    "\n",
    "    def __init__(self, data_dir='data/'):\n",
    "        \"\"\"初始化参数\"\"\"\n",
    "        self.algorithm = \"lbfgs\"\n",
    "        self.c1 =\"0.1\"\n",
    "        self.c2 = \"0.1\"\n",
    "        self.max_iterations = 100\n",
    "        self.model_path = data_dir + \"model.pkl\"\n",
    "        self.corpus = CorpusProcess(data_dir)  #Corpus 实例\n",
    "        self.corpus.pre_process()  #语料预处理\n",
    "        self.corpus.initialize()  #初始化语料\n",
    "        self.model = None\n",
    "\n",
    "    def initialize_model(self):\n",
    "        \"\"\"初始化\"\"\"\n",
    "        algorithm = self.algorithm\n",
    "        c1 = float(self.c1)\n",
    "        c2 = float(self.c2)\n",
    "        max_iterations = int(self.max_iterations)\n",
    "        self.model = sklearn_crfsuite.CRF(algorithm=algorithm, c1=c1, c2=c2,\n",
    "                                          max_iterations=max_iterations, all_possible_transitions=True)\n",
    "\n",
    "    def train(self):\n",
    "        \"\"\"训练\"\"\"\n",
    "        self.initialize_model()\n",
    "        x, y = self.corpus.generator()\n",
    "        x_train, y_train = x[500:], y[500:]\n",
    "        x_test, y_test = x[:500], y[:500]\n",
    "        self.model.fit(x_train, y_train)\n",
    "        labels = list(self.model.classes_)\n",
    "        labels.remove('O')\n",
    "        y_predict = self.model.predict(x_test)\n",
    "        metrics.flat_f1_score(y_test, y_predict, average='weighted', labels=labels)\n",
    "        sorted_labels = sorted(labels, key=lambda name: (name[1:], name[0]))\n",
    "        print(metrics.flat_classification_report(y_test, y_predict, labels=sorted_labels, digits=3))\n",
    "        self.save_model()\n",
    "\n",
    "    def predict(self, sentence):\n",
    "        \"\"\"预测\"\"\"\n",
    "        self.load_model()\n",
    "        u_sent = self.corpus.q_to_b(sentence)\n",
    "        word_lists = [['<BOS>']+[c for c in u_sent]+['<EOS>']]\n",
    "        word_grams = [self.corpus.segment_by_window(word_list) for word_list in word_lists]\n",
    "        features = self.corpus.extract_feature(word_grams)\n",
    "        y_predict = self.model.predict(features)\n",
    "        entity = ''\n",
    "        for index in range(len(y_predict[0])):\n",
    "            if y_predict[0][index] != 'O':\n",
    "                if index > 0 and y_predict[0][index][-1] != y_predict[0][index-1][-1]:\n",
    "                    entity += ' '\n",
    "                entity += u_sent[index]\n",
    "            elif entity[-1] != ' ':\n",
    "                entity += ' '\n",
    "        return entity\n",
    "\n",
    "    def load_model(self):\n",
    "        \"\"\"加载模型 \"\"\"\n",
    "        self.model = pickle.load(open(self.model_path, 'rb'))\n",
    "\n",
    "    def save_model(self):\n",
    "        \"\"\"保存模型\"\"\"\n",
    "        pickle.dump(self.model, open(self.model_path, 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<__main__.CRF_NER at 0x7fd21cd007c0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ner = CRF_NER(data_dir='data/')\n",
    "ner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/xuming/opt/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass labels=['B_LOC', 'I_LOC', 'B_ORG', 'I_ORG', 'B_PER', 'I_PER', 'B_T', 'I_T'] as keyword args. From version 1.0 (renaming of 0.25) passing these as positional arguments will result in an error\n",
      "  warnings.warn(f\"Pass {args_msg} as keyword args. From version \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       B_LOC      0.929     0.492     0.644       266\n",
      "       I_LOC      0.807     0.420     0.552      1203\n",
      "       B_ORG      0.861     0.670     0.754       682\n",
      "       I_ORG      0.844     0.603     0.703       997\n",
      "       B_PER      0.929     0.330     0.487       440\n",
      "       I_PER      0.949     0.364     0.526       824\n",
      "         B_T      0.925     0.831     0.875       444\n",
      "         I_T      0.953     0.902     0.927      1099\n",
      "\n",
      "   micro avg      0.892     0.588     0.709      5955\n",
      "   macro avg      0.900     0.576     0.683      5955\n",
      "weighted avg      0.889     0.588     0.689      5955\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model = ner.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'新华社 北京 十二月三十一日  中央人民广播电台  刘振英  新华社  张宿堂  今天  一九九七年 '"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ner.predict('新华社北京十二月三十一日电(中央人民广播电台记者刘振英、新华社记者张宿堂)今天是一九九七年的最后一天。')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'一九四九年  毛泽东  中国 '"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ner.predict(u'一九四九年，国庆节，毛泽东同志在天安门城楼上宣布中国共产党从此站起来了！')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本节完。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
